import torch
from transformers import BertTokenizer, BertModel
from scipy.spatial.distance import cosine
import json
from tqdm import tqdm
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.functional.multimodal import clip_score
from functools import partial

# Checkpoint locations (presumably local directories relative to the working
# directory; if absent, transformers may treat them as Hub ids — confirm).
bert_model_path = 'pretrained/bert-large-uncased'
clip_model_path = 'pretrained/clip-vit-large-patch14-336'

# BERT tokenizer + encoder, loaded once at import time and shared by the
# caption-similarity helpers.
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
bert_model = BertModel.from_pretrained(bert_model_path)

# CLIP scorer with the checkpoint path pre-bound, so callers pass only
# (images, prompts).
# NOTE(review): torchmetrics' functional ``clip_score`` appears to load the
# model from ``model_name_or_path`` on every call; if so for the installed
# version, the class-based ``torchmetrics.multimodal.CLIPScore`` would load it
# only once — worth confirming, since this is called 16 times per record.
clip_score_fn = partial(clip_score, model_name_or_path=clip_model_path)

def get_bert_embedding(text):
    """Encode *text* with the module-level BERT model and return position-0 vectors.

    The input is tokenized (truncated to at most 512 tokens, padded), run
    through ``bert_model`` without gradient tracking, and the last hidden state
    at token position 0 is returned as a 2-D NumPy array (one row per input;
    a single string gives one row). Under BertTokenizer's default
    special-token handling, position 0 is the [CLS] token.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        hidden_states = bert_model(**encoded)['last_hidden_state']
    first_token = hidden_states[:, 0, :]
    return first_token.numpy()

def cosine_similarity_bert(s1, s2):
    """Return the cosine similarity between the BERT embeddings of two strings.

    Each string is embedded with ``get_bert_embedding`` and flattened to 1-D.
    SciPy's ``cosine`` returns a *distance*, so the similarity is one minus it
    (1 means the vectors point in the same direction).
    """
    first, second = (get_bert_embedding(s).flatten() for s in (s1, s2))
    return 1 - cosine(first, second)

def calculate_clip_score(images, prompts):
    """Compute the CLIP score between image(s) and prompt(s) via ``clip_score_fn``.

    Args:
        images: a PIL image, a list/tuple of same-sized PIL images, or an
            array-like shaped [H, W, C] or [N, H, W, C]. Floating-point arrays
            are assumed to hold values in [0, 1] and are rescaled to 0-255;
            PIL images and integer arrays are assumed to already be 0-255.
        prompts: text prompt(s) forwarded unchanged to ``clip_score_fn``.

    Returns:
        The score as a Python float rounded to 4 decimal places.
    """
    # Force 3-channel RGB so RGBA / palette / grayscale PNGs produce the
    # [H, W, 3] layout the permute below expects.
    if isinstance(images, Image.Image):
        images = images.convert("RGB")
    elif isinstance(images, (list, tuple)):
        images = [im.convert("RGB") if isinstance(im, Image.Image) else im for im in images]

    pixels = np.asarray(images)
    if np.issubdtype(pixels.dtype, np.floating):
        # Float inputs are taken to be in [0, 1]; clip so out-of-range values
        # don't wrap when cast to uint8.
        pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype("uint8")
    else:
        # Integer inputs (e.g. PIL images) are already 0-255. Multiplying a
        # uint8 array by 255 would wrap modulo 256 and corrupt every pixel.
        pixels = pixels.astype("uint8")

    if pixels.ndim == 3:  # [H, W, C] -> [1, H, W, C]
        pixels = np.expand_dims(pixels, axis=0)
    batch = torch.from_numpy(pixels).permute(0, 3, 1, 2)  # NHWC -> NCHW
    score = clip_score_fn(batch, prompts).detach()
    return round(float(score), 4)


# ---------------------------------------------------------------------------
# Build the iteration-1 DPO pairs.
#
# For each record of the caption file (paired by list position with the
# previous round's DPO file):
#   * image pair   - CLIP-score each generated image 0.png .. 15.png against
#     the ground-truth caption; the best-scoring one becomes "image_win";
#   * caption pair - BERT-cosine-score each entry of "caption_mmu" against the
#     ground-truth caption; the best-scoring one becomes "caption_win";
#   * the loser is the previous round's winner when the new best beats it,
#     otherwise the previous round's loser.
# Results are written to OUTPUT_PATH as a single JSON array.
# ---------------------------------------------------------------------------
GENERATED_IMAGE_DIR = "datasets/journydb/generated_images_unidpo"
NUM_GENERATED_IMAGES = 16  # generated images per record: 0.png .. 15.png
OUTPUT_PATH = "datasets/journydb/dpo_data_iteration1.json"

with open('datasets/journydb/caption/caption_data_unidop.json', "r", encoding="utf-8") as f:
    data = json.load(f)

with open('datasets/journydb/dpo_data.json', "r", encoding="utf-8") as f:
    initial_dpo_data = json.load(f)

# NOTE(review): records are matched by list position, which assumes both files
# list the same ids in the same order — confirm. Fail fast here instead of
# hitting an IndexError after hours of scoring.
if len(initial_dpo_data) < len(data):
    raise ValueError(
        f"dpo_data.json has {len(initial_dpo_data)} records but the caption "
        f"file has {len(data)}; they must be aligned by position"
    )

with open(OUTPUT_PATH, "w", encoding="utf-8") as out:
    # Stream a JSON array and flush after each record: partial progress
    # survives a crash, and the finished file loads with json.load (bare
    # objects joined by ",\n" would not). "w" also avoids duplicating entries
    # when the script is re-run.
    out.write("[\n")
    for i, record in enumerate(tqdm(data)):
        previous = initial_dpo_data[i]
        gt_caption = record["caption"]

        # ---- image pair: CLIP score vs. the ground-truth caption ----
        image_dir = f"{GENERATED_IMAGE_DIR}/{record['id']}"
        # Named ``clip_scores`` so the imported ``clip_score`` is not shadowed.
        clip_scores = []
        for j in range(NUM_GENERATED_IMAGES):
            with Image.open(f"{image_dir}/{j}.png") as img:
                clip_scores.append(calculate_clip_score(img, gt_caption))
        clip_score_win = max(clip_scores)
        # list.index returns the first maximum (ties go to the lowest index).
        image_win_path = f"{image_dir}/{clip_scores.index(clip_score_win)}.png"
        if clip_score_win > previous["clip_score_win"]:
            # New best beats last round's winner, which is demoted to loser.
            image_lose_path = previous["image_win"]
            clip_score_lose = previous["clip_score_win"]
        else:
            # NOTE(review): the new best is still used as the winner; if it
            # also scores below previous["clip_score_lose"], this pair's winner
            # scores lower than its loser. Confirm that is intended.
            image_lose_path = previous["image_lose"]
            clip_score_lose = previous["clip_score_lose"]

        # ---- caption pair: BERT cosine similarity vs. the GT caption ----
        mmu_captions = record["caption_mmu"]
        bert_scores = [cosine_similarity_bert(gt_caption, c) for c in mmu_captions]
        bert_score_win = max(bert_scores)
        caption_win = mmu_captions[bert_scores.index(bert_score_win)]
        if bert_score_win > previous["bert_score_win"]:
            caption_lose = previous["caption_win"]
            bert_score_lose = previous["bert_score_win"]
        else:
            # NOTE(review): same possible winner/loser inversion as for images.
            caption_lose = previous["caption_lose"]
            bert_score_lose = previous["bert_score_lose"]

        if i:
            out.write(",\n")
        json.dump({
            "id": record["id"],
            "img_path": record["img_path"],
            "prompt": record["prompt"],
            "caption": record["caption"],
            "caption_win": caption_win,
            "caption_lose": caption_lose,
            "bert_score_win": bert_score_win,
            "bert_score_lose": bert_score_lose,
            "image_win": image_win_path,
            "image_lose": image_lose_path,
            "clip_score_win": clip_score_win,
            "clip_score_lose": clip_score_lose
        }, out, ensure_ascii=False, indent=4)
        out.flush()
    out.write("\n]\n")